{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.timesnet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TimesNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The TimesNet univariate model tackles the challenge of modeling multiple intraperiod and interperiod temporal variations.\n",
    "\n",
    "The architecture has the following distinctive features:\n",
    "- An embedding layer that maps the input sequence into a latent space.\n",
    "- Transformation of 1D time seires into 2D tensors, based on periods found by FFT.\n",
    "- A convolutional Inception block that captures temporal variations at different scales and between periods.\n",
    "\n",
    "**References**<br>\n",
    "- [Haixu Wu and Tengge Hu and Yong Liu and Hang Zhou and Jianmin Wang and Mingsheng Long. TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis](https://openreview.net/pdf?id=ju_Uqw384Oq)\n",
    "- Based on the implementation in https://github.com/thuml/Time-Series-Library (license: https://github.com/thuml/Time-Series-Library/blob/main/LICENSE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. TimesNet Architecture.](imgs_models/timesnet.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.fft\n",
    "\n",
    "from neuralforecast.common._modules import DataEmbedding\n",
    "from neuralforecast.common._base_windows import BaseWindows\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Auxiliary Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Inception_Block_V1(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):\n",
    "        super(Inception_Block_V1, self).__init__()\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.num_kernels = num_kernels\n",
    "        kernels = []\n",
    "        for i in range(self.num_kernels):\n",
    "            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))\n",
    "        self.kernels = nn.ModuleList(kernels)\n",
    "        if init_weight:\n",
    "            self._initialize_weights()\n",
    "\n",
    "    def _initialize_weights(self):\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
    "                if m.bias is not None:\n",
    "                    nn.init.constant_(m.bias, 0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        res_list = []\n",
    "        for i in range(self.num_kernels):\n",
    "            res_list.append(self.kernels[i](x))\n",
    "        res = torch.stack(res_list, dim=-1).mean(-1)\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def FFT_for_Period(x, k=2):\n",
    "    # [B, T, C]\n",
    "    xf = torch.fft.rfft(x, dim=1)\n",
    "    # find period by amplitudes\n",
    "    frequency_list = abs(xf).mean(0).mean(-1)\n",
    "    frequency_list[0] = 0\n",
    "    _, top_list = torch.topk(frequency_list, k)\n",
    "    top_list = top_list.detach().cpu().numpy()\n",
    "    period = x.shape[1] // top_list\n",
    "    return period, abs(xf).mean(-1)[:, top_list]\n",
    "\n",
    "class TimesBlock(nn.Module):\n",
    "    def __init__(self, input_size, h, k, hidden_size, conv_hidden_size, num_kernels):\n",
    "        super(TimesBlock, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.h = h\n",
    "        self.k = k\n",
    "        # parameter-efficient design\n",
    "        self.conv = nn.Sequential(\n",
    "            Inception_Block_V1(hidden_size, conv_hidden_size,\n",
    "                               num_kernels=num_kernels),\n",
    "            nn.GELU(),\n",
    "            Inception_Block_V1(conv_hidden_size, hidden_size,\n",
    "                               num_kernels=num_kernels)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, T, N = x.size()\n",
    "        period_list, period_weight = FFT_for_Period(x, self.k)\n",
    "\n",
    "        res = []\n",
    "        for i in range(self.k):\n",
    "            period = period_list[i]\n",
    "            # padding\n",
    "            if (self.input_size + self.h) % period != 0:\n",
    "                length = (\n",
    "                                 ((self.input_size + self.h) // period) + 1) * period\n",
    "                padding = torch.zeros([x.shape[0], (length - (self.input_size + self.h)), x.shape[2]]).to(x.device)\n",
    "                out = torch.cat([x, padding], dim=1)\n",
    "            else:\n",
    "                length = (self.input_size + self.h)\n",
    "                out = x\n",
    "            # reshape\n",
    "            out = out.reshape(B, length // period, period,\n",
    "                              N).permute(0, 3, 1, 2).contiguous()\n",
    "            # 2D conv: from 1d Variation to 2d Variation\n",
    "            out = self.conv(out)\n",
    "            # reshape back\n",
    "            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)\n",
    "            res.append(out[:, :(self.input_size + self.h), :])\n",
    "        res = torch.stack(res, dim=-1)\n",
    "        # adaptive aggregation\n",
    "        period_weight = F.softmax(period_weight, dim=1)\n",
    "        period_weight = period_weight.unsqueeze(\n",
    "            1).unsqueeze(1).repeat(1, T, N, 1)\n",
    "        res = torch.sum(res * period_weight, -1)\n",
    "        # residual connection\n",
    "        res = res + x\n",
    "        return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. TimesNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TimesNet(BaseWindows):\n",
    "    \"\"\" TimesNet\n",
    "\n",
    "    The TimesNet univariate model tackles the challenge of modeling multiple intraperiod and interperiod temporal variations.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    h : int\n",
    "        Forecast horizon.\n",
    "    input_size : int\n",
    "        Length of input window (lags).\n",
    "    futr_exog_list : list of str, optional (default=None)\n",
    "        Future exogenous columns.\n",
    "    hist_exog_list : list of str, optional (default=None)\n",
    "        Historic exogenous columns.\n",
    "    stat_exog_list : list of str, optional (default=None)\n",
    "        Static exogenous columns.\n",
    "    exclude_insample_y : bool (default=False)\n",
    "        The model skips the autoregressive features y[t-input_size:t] if True\n",
    "    hidden_size : int (default=64)\n",
    "        Size of embedding for embedding and encoders.\n",
    "    dropout : float between [0, 1) (default=0.1)\n",
    "        Dropout for embeddings.\n",
    "\tconv_hidden_size: int (default=64)\n",
    "        Channels of the Inception block.\n",
    "    top_k: int (default=5)\n",
    "        Number of periods.\n",
    "    num_kernels: int (default=6)\n",
    "        Number of kernels for the Inception block.\n",
    "    encoder_layers : int, (default=2)\n",
    "        Number of encoder layers.\n",
    "    loss: PyTorch module (default=MAE())\n",
    "        Instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).\n",
    "    valid_loss: PyTorch module (default=None, uses loss)\n",
    "        Instantiated validation loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).\n",
    "    max_steps: int (default=1000)\n",
    "        Maximum number of training steps.\n",
    "    learning_rate : float (default=1e-4)\n",
    "        Learning rate.\n",
    "    num_lr_decays`: int (default=-1)\n",
    "        Number of learning rate decays, evenly distributed across max_steps. If -1, no learning rate decay is performed.\n",
    "    early_stop_patience_steps : int (default=-1)\n",
    "        Number of validation iterations before early stopping. If -1, no early stopping is performed.\n",
    "    val_check_steps : int (default=100)\n",
    "        Number of training steps between every validation loss check.\n",
    "    batch_size : int (default=32)\n",
    "        Number of different series in each batch.\n",
    "    valid_batch_size : int (default=None)\n",
    "        Number of different series in each validation and test batch, if None uses batch_size.\n",
    "    windows_batch_size : int (default=64)\n",
    "        Number of windows to sample in each training batch.\n",
    "    inference_windows_batch_size : int (default=256)\n",
    "        Number of windows to sample in each inference batch.\n",
    "    start_padding_enabled : bool (default=False)\n",
    "        If True, the model will pad the time series with zeros at the beginning by input size.\n",
    "    scaler_type : str (default='standard')\n",
    "        Type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    random_seed : int (default=1)\n",
    "        Random_seed for pytorch initializer and numpy generators.\n",
    "    num_workers_loader : int (default=0)\n",
    "        Workers to be used by `TimeSeriesDataLoader`.\n",
    "    drop_last_loader : bool (default=False)\n",
    "        If True `TimeSeriesDataLoader` drops last non-full batch.\n",
    "    **trainer_kwargs\n",
    "        Keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer)\n",
    "\n",
    "\tReferences\n",
    "\t----------\n",
    "    Haixu Wu and Tengge Hu and Yong Liu and Hang Zhou and Jianmin Wang and Mingsheng Long. TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis. https://openreview.net/pdf?id=ju_Uqw384Oq\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = 'windows'\n",
    "    \n",
    "    def __init__(self,\n",
    "                 h: int, \n",
    "                 input_size: int,\n",
    "                 stat_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 futr_exog_list = None,\n",
    "                 exclude_insample_y = False,\n",
    "                 hidden_size: int = 64, \n",
    "                 dropout: float = 0.1,\n",
    "                 conv_hidden_size: int = 64,\n",
    "                 top_k: int = 5,\n",
    "                 num_kernels: int = 6,\n",
    "                 encoder_layers: int = 2,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 1000,\n",
    "                 learning_rate: float = 1e-4,\n",
    "                 num_lr_decays: int = -1,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size = 64,\n",
    "                 inference_windows_batch_size = 256,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'standard',\n",
    "                 random_seed: int = 1,\n",
    "                 num_workers_loader: int = 0,\n",
    "                 drop_last_loader: bool = False,\n",
    "                 **trainer_kwargs):\n",
    "        super(TimesNet, self).__init__(h=h,\n",
    "                                       input_size=input_size,\n",
    "                                       hist_exog_list=hist_exog_list,\n",
    "                                       stat_exog_list=stat_exog_list,\n",
    "                                       futr_exog_list = futr_exog_list,\n",
    "                                       exclude_insample_y = exclude_insample_y,\n",
    "                                       loss=loss,\n",
    "                                       valid_loss=valid_loss,\n",
    "                                       max_steps=max_steps,\n",
    "                                       learning_rate=learning_rate,\n",
    "                                       num_lr_decays=num_lr_decays,\n",
    "                                       early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                       val_check_steps=val_check_steps,\n",
    "                                       batch_size=batch_size,\n",
    "                                       windows_batch_size=windows_batch_size,\n",
    "                                       valid_batch_size=valid_batch_size,\n",
    "                                       inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                       start_padding_enabled = start_padding_enabled,\n",
    "                                       step_size=step_size,\n",
    "                                       scaler_type=scaler_type,\n",
    "                                       num_workers_loader=num_workers_loader,\n",
    "                                       drop_last_loader=drop_last_loader,\n",
    "                                       random_seed=random_seed,\n",
    "                                       **trainer_kwargs)\n",
    "\n",
    "        # Architecture\n",
    "        self.futr_input_size = len(self.futr_exog_list)\n",
    "        self.hist_input_size = len(self.hist_exog_list)\n",
    "        self.stat_input_size = len(self.stat_exog_list)\n",
    "\n",
    "        if self.stat_input_size > 0:\n",
    "            raise Exception('TimesNet does not support static variables yet')\n",
    "        if self.hist_input_size > 0:\n",
    "            raise Exception('TimesNet does not support historical variables yet')\n",
    "\n",
    "        self.c_out = self.loss.outputsize_multiplier\n",
    "        self.enc_in = 1 \n",
    "        self.dec_in = 1\n",
    "\n",
    "        self.model = nn.ModuleList([TimesBlock(input_size=input_size,\n",
    "                                               h=h,\n",
    "                                               k=top_k,\n",
    "                                               hidden_size=hidden_size,\n",
    "                                               conv_hidden_size=conv_hidden_size,\n",
    "                                               num_kernels=num_kernels)\n",
    "                                    for _ in range(encoder_layers)])\n",
    "\n",
    "        self.enc_embedding = DataEmbedding(c_in=self.enc_in,\n",
    "                                            exog_input_size=self.futr_input_size,\n",
    "                                            hidden_size=hidden_size, \n",
    "                                            pos_embedding=True, # Original implementation uses true\n",
    "                                            dropout=dropout)\n",
    "        self.encoder_layers = encoder_layers\n",
    "        self.layer_norm = nn.LayerNorm(hidden_size)\n",
    "        self.predict_linear = nn.Linear(self.input_size, self.h + self.input_size)\n",
    "        self.projection = nn.Linear(hidden_size, self.c_out, bias=True)\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "\n",
    "        # Parse windows_batch\n",
    "        insample_y    = windows_batch['insample_y']\n",
    "        #insample_mask = windows_batch['insample_mask']\n",
    "        #hist_exog     = windows_batch['hist_exog']\n",
    "        #stat_exog     = windows_batch['stat_exog']\n",
    "        futr_exog     = windows_batch['futr_exog']\n",
    "\n",
    "        # Parse inputs\n",
    "        insample_y = insample_y.unsqueeze(-1) # [Ws,L,1]\n",
    "        if self.futr_input_size > 0:\n",
    "            x_mark_enc = futr_exog[:,:self.input_size,:]\n",
    "        else:\n",
    "            x_mark_enc = None\n",
    "\n",
    "        # embedding\n",
    "        enc_out = self.enc_embedding(insample_y, x_mark_enc)\n",
    "        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)  # align temporal dimension\n",
    "        # TimesNet\n",
    "        for i in range(self.encoder_layers):\n",
    "            enc_out = self.layer_norm(self.model[i](enc_out))\n",
    "        # porject back\n",
    "        dec_out = self.projection(enc_out)\n",
    "\n",
    "        forecast = self.loss.domain_map(dec_out[:, -self.h:])\n",
    "        return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimesNet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimesNet.fit, name='TimesNet.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimesNet.predict, name='TimesNet.predict')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
    "\n",
    "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
    "\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "model = TimesNet(h=12,\n",
    "                 input_size=24,\n",
    "                 hidden_size = 16,\n",
    "                 conv_hidden_size = 32,\n",
    "                 #loss=MAE(),\n",
    "                 #loss=MQLoss(quantiles=[0.2, 0.5, 0.8]),\n",
    "                 loss=DistributionLoss(distribution='Normal', level=[80, 90]),\n",
    "                 futr_exog_list=calendar_cols,\n",
    "                 scaler_type='standard',\n",
    "                 learning_rate=1e-3,\n",
    "                 max_steps=5,\n",
    "                 val_check_steps=50,\n",
    "                 early_stop_patience_steps=2)\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[model],\n",
    "    freq='M'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "forecasts = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "if model.loss.is_distribution_output:\n",
    "    plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "    plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "    plt.plot(plot_df['ds'], plot_df['TimesNet-median'], c='blue', label='median')\n",
    "    plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                    y1=plot_df['TimesNet-lo-90'][-12:].values, \n",
    "                    y2=plot_df['TimesNet-hi-90'][-12:].values,\n",
    "                    alpha=0.4, label='level 90')\n",
    "    plt.grid()\n",
    "    plt.legend()\n",
    "    plt.plot()\n",
    "else:\n",
    "    plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "    plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "    plt.plot(plot_df['ds'], plot_df['TimesNet'], c='blue', label='Forecast')\n",
    "    plt.legend()\n",
    "    plt.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
